{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": true,
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "## 使用TorchText库完成文本分类\n",
    "\n",
    " * Access to the raw data as an iterator\n",
    " * Build data processing pipeline to convert the raw text strings into torch.Tensor that can be used to train the model\n",
    " * Shuffle and iterate the data with torch.utils.data.DataLoader\n",
    "\n",
    "使用 AG_NEWS 新闻数据集，做主题分类"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "import torch\n",
    "from torch.utils.data import DataLoader\n",
    "from torchtext.datasets import AG_NEWS\n",
    "from torchtext.data.utils import get_tokenizer\n",
    "\n",
    "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "tokenizer = get_tokenizer('basic_english')\n",
    "print('torch:', torch.__version__)"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "markdown",
   "source": [
    "### 1 准备数据集"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%% md\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "# 使用 build_vocab_from_iterator 预处理\n",
    "\n",
    "from torchtext.vocab import build_vocab_from_iterator\n",
    "train_iter = AG_NEWS(split='train')\n",
    "def yield_tokens(data_iter):\n",
    "    for _, text in data_iter:\n",
    "        yield tokenizer(text)\n",
    "\n",
    "vocab = build_vocab_from_iterator(yield_tokens(train_iter), specials=[\"<unk>\"])\n",
    "vocab.set_default_index(vocab[\"<unk>\"])\n",
    "VOCAB_SIZE = len(vocab)\n",
    "print(f'VOCAB_SIZE: {VOCAB_SIZE}')"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "# Prepare data processing pipelines\n",
    "text_pipeline = lambda x: vocab(tokenizer(x))\n",
    "label_pipeline = lambda x: int(x) - 1\n",
    "\n",
    "def collate_batch(batch):\n",
    "    label_list, text_list, offsets = [], [], [0]\n",
    "    for (_label, _text) in batch:\n",
    "         label_list.append(label_pipeline(_label))\n",
    "         processed_text = torch.tensor(text_pipeline(_text), dtype=torch.int64)\n",
    "         text_list.append(processed_text)\n",
    "         offsets.append(processed_text.size(0))\n",
    "    label_list = torch.tensor(label_list, dtype=torch.int64)\n",
    "    offsets = torch.tensor(offsets[:-1]).cumsum(dim=0)\n",
    "    text_list = torch.cat(text_list)\n",
    "    return label_list.to(device), text_list.to(device), offsets.to(device)"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "#  划分训练，验证集，\n",
    "from torch.utils.data.dataset import random_split\n",
    "from torchtext.data.functional import to_map_style_dataset\n",
    "# Hyperparameters\n",
    "EPOCHS = 10 # epoch\n",
    "LR = 5  # learning rate\n",
    "BATCH_SIZE = 64 # batch size for training\n",
    "\n",
    "total_accu = None\n",
    "train_iter, test_iter = AG_NEWS()\n",
    "\n",
    "train_dataset = to_map_style_dataset(train_iter)\n",
    "test_dataset = to_map_style_dataset(test_iter)\n",
    "num_train = int(len(train_dataset) * 0.95)\n",
    "split_train_, split_valid_ = random_split(train_dataset, [num_train, len(train_dataset) - num_train])\n",
    "\n",
    "train_dataloader = DataLoader(split_train_, batch_size=BATCH_SIZE,shuffle=True, collate_fn=collate_batch)\n",
    "valid_dataloader = DataLoader(split_valid_, batch_size=BATCH_SIZE,shuffle=True, collate_fn=collate_batch)\n",
    "test_dataloader = DataLoader(test_dataset, batch_size=BATCH_SIZE,shuffle=True, collate_fn=collate_batch)"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "markdown",
   "source": [
    "### 2 训练\n",
    "1. 定义模型:  nn.EmbeddingBag layer plus a linear layer for the classification purpose.\n",
    "2. 训练和评估代码\n",
    "3. 开始训练"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%% md\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "from torch import nn\n",
    "\n",
    "class TextClassificationModel(nn.Module):\n",
    "\n",
    "    def __init__(self, vocab_size, embed_dim, num_class):\n",
    "        super(TextClassificationModel, self).__init__()\n",
    "        self.embedding = nn.EmbeddingBag(vocab_size, embed_dim, sparse=True)\n",
    "        self.fc = nn.Linear(embed_dim, num_class)\n",
    "        self.init_weights()\n",
    "\n",
    "    def init_weights(self):\n",
    "        initrange = 0.5\n",
    "        self.embedding.weight.data.uniform_(-initrange, initrange)\n",
    "        self.fc.weight.data.uniform_(-initrange, initrange)\n",
    "        self.fc.bias.data.zero_()\n",
    "\n",
    "    def forward(self, text, offsets):\n",
    "        embedded = self.embedding(text, offsets)\n",
    "        return self.fc(embedded)"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "EMBEDDING_SIZE = 64\n",
    "NUM_CLASS = 4\n",
    "model = TextClassificationModel(VOCAB_SIZE, EMBEDDING_SIZE, NUM_CLASS).to(device)\n",
    "\n",
    "\n",
    "\n",
    "criterion = torch.nn.CrossEntropyLoss()\n",
    "optimizer = torch.optim.SGD(model.parameters(), lr=LR)\n",
    "scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 1.0, gamma=0.1)\n",
    "\n",
    "print(f'NUM_CLASS:{NUM_CLASS}, VOCAB_SIZE:{VOCAB_SIZE}, EMBEDDING_SIZE:{EMBEDDING_SIZE}')"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "import time\n",
    "\n",
    "def train(dataloader):\n",
    "    model.train()\n",
    "    total_acc, total_count = 0, 0\n",
    "    log_interval = 500\n",
    "    start_time = time.time()\n",
    "\n",
    "    for idx, (label, text, offsets) in enumerate(dataloader):\n",
    "        optimizer.zero_grad()\n",
    "        predicted_label = model(text, offsets)\n",
    "        loss = criterion(predicted_label, label)\n",
    "        loss.backward()\n",
    "        torch.nn.utils.clip_grad_norm_(model.parameters(), 0.1)\n",
    "        optimizer.step()\n",
    "        total_acc += (predicted_label.argmax(1) == label).sum().item()\n",
    "        total_count += label.size(0)\n",
    "        if idx % log_interval == 0 and idx > 0:\n",
    "            elapsed = time.time() - start_time\n",
    "            print('| epoch {:3d} | {:5d}/{:5d} batches '\n",
    "                  '| accuracy {:8.3f}'.format(epoch, idx, len(dataloader),\n",
    "                                              total_acc/total_count))\n",
    "            total_acc, total_count = 0, 0\n",
    "            start_time = time.time()\n",
    "\n",
    "def evaluate(dataloader):\n",
    "    model.eval()\n",
    "    total_acc, total_count = 0, 0\n",
    "\n",
    "    with torch.no_grad():\n",
    "        for idx, (label, text, offsets) in enumerate(dataloader):\n",
    "            predicted_label = model(text, offsets)\n",
    "            loss = criterion(predicted_label, label)\n",
    "            total_acc += (predicted_label.argmax(1) == label).sum().item()\n",
    "            total_count += label.size(0)\n",
    "    return total_acc/total_count"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "# 启动训练\n",
    "for epoch in range(1, EPOCHS + 1):\n",
    "    epoch_start_time = time.time()\n",
    "    train(train_dataloader)\n",
    "    accu_val = evaluate(valid_dataloader)\n",
    "    if total_accu is not None and total_accu > accu_val:\n",
    "      scheduler.step()\n",
    "    else:\n",
    "       total_accu = accu_val\n",
    "    print('-' * 59)\n",
    "    print('| end of epoch {:3d} | time: {:5.2f}s | '\n",
    "          'valid accuracy {:8.3f} '.format(epoch,time.time() - epoch_start_time,accu_val))\n"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "# 启动测试\n",
    "print('Checking the results of test dataset.')\n",
    "accu_test = evaluate(test_dataloader)\n",
    "print('test accuracy {:8.3f}'.format(accu_test))\n"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "# 随机文本，测试\n",
    "\n",
    "ag_news_label = {1: \"World\",\n",
    "                 2: \"Sports\",\n",
    "                 3: \"Business\",\n",
    "                 4: \"Sci/Tec\"}\n",
    "\n",
    "def predict(text, text_pipeline):\n",
    "    with torch.no_grad():\n",
    "        text = torch.tensor(text_pipeline(text))\n",
    "        output = model(text, torch.tensor([0]))\n",
    "        return output.argmax(1).item() + 1\n",
    "\n",
    "ex_text_str = \"Exports of mechanical and electrical products jumped 23.8 percent year-on-year to 7.98 trillion yuan in the first eight months of this year, \" \\\n",
    "              \"accounting for 58.8 percent of the nation's total export value,\" \\\n",
    "              \" according to the General Administration of Customs.\"\n",
    "\n",
    "model = model.to(\"cpu\")\n",
    "\n",
    "print(\"This is a %s news\" %ag_news_label[predict(ex_text_str, text_pipeline)])"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  }
 ],
 "metadata": {
  "kernelspec": {
   "name": "pycharm-cf2a9c60",
   "language": "python",
   "display_name": "PyCharm (machine-learning-pytorch)"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}